from torch import nn
from MFCC_LCNN.lfcc_model_LCNN import LightCNN_29Layers as MFCCLCNN

class Model_zoo(nn.Module):
    """Binary classifier built around a named backbone network.

    Pipeline: insert a channel axis -> ``Conv2d(1, 3)`` stem -> backbone
    selected by ``mt`` -> ``Linear(1000, 2)`` head.

    The 1->3 channel stem and the 1000->2 head match the interface of
    ImageNet-style backbones (3-channel input, 1000-way output). That is
    what the previously registered torchvision models used.
    NOTE(review): whether ``LightCNN_29Layers`` accepts 3 input channels and
    emits 1000 features cannot be verified from this file -- confirm.

    Args:
        mt: Key of the backbone to use. Only ``"lcnn_mfcc"`` is registered;
            any other value raises ``KeyError`` from the registry lookup.
        device: Stored as ``self.device``; not used anywhere in this class.
        parser1: Accepted for call compatibility; ignored.
    """

    def __init__(self, mt, device=None, parser1=None):
        super().__init__()
        self.mt = mt
        self.device = device
        # Name -> backbone instance. Every entry is instantiated here, before
        # the lookup below. The disabled entries were the torchvision models
        # shufflenet_v2_x0_5, densenet121, mobilenet_v2 and resnet18.
        registry = {"lcnn_mfcc": MFCCLCNN()}
        self.n_m = registry
        self.model = registry[mt]
        # kernel 3 / stride 1 / padding 1 keeps height and width unchanged.
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        # Maps the backbone's (assumed) 1000-dim output to 2 classes.
        self.fc = nn.Linear(1000, 2)

    def forward(self, x, Freq_aug=False):
        """Return 2-class logits for ``x``.

        Args:
            x: Input tensor. A channel axis is inserted at dim 1 before the
                stem conv.
                NOTE(review): presumably ``x`` is (batch, frames, num_filter),
                so the result is the 4-D (batch, 1, frames, num_filter) that
                Conv2d expects -- confirm.
                An older note described the input as (64600), i.e. a raw
                waveform. That looks stale: a 1-D tensor becomes 2-D after the
                unsqueeze, which Conv2d rejects.
            Freq_aug: Accepted but ignored.

        Returns:
            ``self.fc`` applied to the backbone output.
        """
        with_channel = x.unsqueeze(dim=1)
        # 1 input channel -> 3 channels for the backbone.
        stem_out = self.conv(with_channel)
        backbone_out = self.model(stem_out)
        # Project the backbone output down to 2 classes.
        return self.fc(backbone_out)

from ASVBaselineTool.DEMONEED import MFCCfeature
def demo():
    """Smoke-test the LCNN wrapper on the demo feature(s).

    For each selected feature:
    - build the ``Model_zoo`` variant mapped to it;
    - run one forward pass with the feature itself as the input;
    - print the output shape, then a confirmation message ("lcnn model <name>
      feature output OK").

    NOTE(review): ``MFCCfeature`` is presumably a precomputed feature tensor
    from ``ASVBaselineTool.DEMONEED`` -- confirm. The network is left in its
    default training mode with autograd enabled during the forward pass.
    """
    use_feature = MFCCfeature
    # Checks that the LCNN forward path runs end to end.
    # Maps each demo feature to the Model_zoo key of the LCNN that consumes
    # it. Only the MFCC pairing is active; the LFCC / CQCC / spectrogram
    # pairings ("lcnn_lfcc", "lcnn_cqcc", "lcnn_spec") are disabled.
    feature_to_net = {MFCCfeature: "lcnn_mfcc"}
    for feature in (use_feature,):
        net_name = feature_to_net[feature]
        net = Model_zoo(net_name)
        print(net(feature).shape)
        print("lcnn 模型 {} 特征输出正常".format(net_name))


# Run the smoke test only when this file is executed as a script, not on import.
if __name__=="__main__":
    demo()